# run_pm_da_mcts_walker2d.py
# ------------------------------------------------------------
# Power-Mean Dimension-Adaptive MCTS (PM-DA-MCTS) for Walker2d
# - ε-net hierarchy with k(n) = min{k: n <= |N_k|^2}
# - Polynomial exploration bonus: C * N_s^(1/4) / N_{s,a}^(1/2)
# - Power parameter p=POWER (Node.value_power_mean and the per-action sum_p
#   accumulator); in-tree backups are running means of sampled returns
# - Compatible run_seed(...) signature for existing tooling
# ------------------------------------------------------------

import math
import gym
import random
import numpy as np
import argparse
import statistics
import pickle
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import gym
# Imported only for its side effect of registering ImprovedWalker2d-v0 with gym:
import improved_walker2d  # noqa: F401  (ensures ImprovedWalker2d-v0 is registered)

from SnapshotENV import SnapshotEnv  # your snapshot wrapper

# Discount factor γ: applied to the child value in planning backups
# (q = r + DISCOUNT * tail) and to rewards of the executed test trajectory.
DISCOUNT = 0.99


# --------- helpers to read SnapshotEnv result (supports dict or obj) ----------
def _get_attr(obj, name, default=None):
    if isinstance(obj, dict):
        return obj.get(name, default)
    return getattr(obj, name, default)


# ---------------------- Node/Child statistics structures ----------------------

@dataclass
class ActionStats:
    """
    Statistics for a single (state, action) pair.

    Fields (all start at zero):
        visits: number of backups that went through this action.
        mean_q: running average of the backed-up returns.
        sum_p:  power accumulator updated alongside mean_q during backups
                (diagnostic only; not read by action selection).
    """
    visits: int = field(default=0)
    mean_q: float = field(default=0.0)
    sum_p: float = field(default=0.0)


@dataclass
class GridLevel:
    """
    One level of a node's ε-net hierarchy, built on first use.

    Fields:
        epsilon: spacing parameter this level was built with.
        actions: the grid points, each a tuple of per-dimension coordinates.
        stats:   per-action statistics keyed by the same tuples as `actions`.
    """
    epsilon: float
    actions: List[Tuple[float, ...]] = field(default_factory=list)
    stats: Dict[Tuple[float, ...], ActionStats] = field(default_factory=dict)


@dataclass
class Node:
    """Planning-tree node: an env snapshot plus its lazily built ε-net levels."""
    snapshot: Optional[bytes]   # snapshot bytes from SnapshotEnv; None until captured
    depth: int                  # distance from the current planning root
    done: bool                  # terminal flag; terminal nodes are not expanded
    visit_count: int = 0        # number of backups through this node
    grids: Dict[int, "GridLevel"] = field(default_factory=dict)  # level k -> ε-net level

    def value_power_mean(self, p: float) -> float:
        """
        Visit-weighted power mean of the per-action mean returns.

        Every visited action of every built ε-net level contributes its
        mean_q with weight = its visit count:

            (Σ w_i * q_i^p / Σ w_i) ** (1/p)

        Special cases:
          * p == 1 (within 1e-8): weighted arithmetic mean.
          * p == 0 (within 1e-8): weighted geometric mean (the p -> 0 limit),
            instead of dividing by zero in 1/p.
          * The power mean is only defined for non-negative values (strictly
            positive when p <= 0). If any visited mean violates that, the
            weighted arithmetic mean is returned, so the result is always a
            real number with the correct sign (raising negatives to p would
            otherwise give a complex number or drop the sign for even p).

        Returns 0.0 when the node, or every action in it, is unvisited.
        """
        if self.visit_count == 0:
            return 0.0
        pairs = [(st.visits, st.mean_q)
                 for g in self.grids.values()
                 for st in g.stats.values()
                 if st.visits > 0]
        if not pairs:
            return 0.0

        total_w = sum(w for w, _ in pairs)
        arithmetic = sum(w * q for w, q in pairs) / total_w
        if abs(p - 1.0) < 1e-8:
            return arithmetic

        near_zero_p = abs(p) < 1e-8
        min_q = min(q for _, q in pairs)
        if min_q < 0.0 or (min_q == 0.0 and (p < 0.0 or near_zero_p)):
            # Outside the power mean's domain: fall back to the arithmetic mean.
            return arithmetic
        if near_zero_p:
            return math.exp(sum(w * math.log(q) for w, q in pairs) / total_w)
        return (sum(w * (q ** p) for w, q in pairs) / total_w) ** (1.0 / p)


# ----------------------- ε-net construction utilities ------------------------

def build_uniform_grid(low: float, high: float, dim: int, epsilon: float, cap: int) -> List[Tuple[float, ...]]:
    """
    Build a uniform grid over [low, high]^dim with spacing about epsilon.

    Each axis gets max(2, ceil((high - low) / epsilon) + 1) evenly spaced
    points, endpoints included. If that would exceed `cap` points in total,
    the per-axis count is reduced to the largest value >= 2 whose dim-th power
    is <= cap. If even 2 points per axis exceed `cap`, the grid is truncated
    to its first `cap` points.

    Args:
        low, high: bounds shared by every coordinate.
        dim: number of dimensions (>= 1).
        epsilon: target spacing; a non-positive value falls back to the
            coarsest grid (2 points per axis).
        cap: maximum number of points returned.

    Returns:
        List of grid points, each a tuple of `dim` Python floats.
    """
    span = high - low
    if epsilon <= 0:
        # Degenerate spacing -> coarsest grid (the endpoints of each axis).
        epsilon = span
    # Points per axis; +1 so both endpoints are included; never fewer than 2.
    per_axis = max(2, int(math.ceil(span / max(1e-12, epsilon))) + 1)
    if per_axis ** dim > cap:
        # Largest per_axis >= 2 with per_axis**dim <= cap, found with exact
        # integer arithmetic: the float root cap ** (1/dim) can land just
        # below an exact integer root (e.g. 1000 ** (1/3) == 9.999...).
        per_axis = 2
        while (per_axis + 1) ** dim <= cap:
            per_axis += 1

    coords_1d = np.linspace(low, high, num=per_axis)
    grid = np.stack(np.meshgrid(*([coords_1d] * dim), indexing="xy"), axis=-1).reshape(-1, dim)
    if len(grid) > cap:
        # Only reachable when cap < 2**dim.
        # NOTE(review): the first `cap` points in meshgrid order form a
        # corner-biased subset (the slowest-varying coordinate stays at `low`).
        grid = grid[:cap]
    return [tuple(map(float, a)) for a in grid]


def k_from_visits(n_visits: int, grid_sizes: List[int]) -> int:
    """
    Pick the ε-net level for a node with `n_visits` visits.

    Returns the first index k with n_visits <= grid_sizes[k] ** 2, where
    grid_sizes[k] = |N_k|. Non-positive visit counts map to level 0. When no
    level is large enough, the last index (len(grid_sizes) - 1) is returned.
    """
    if n_visits <= 0:
        return 0
    fallback = len(grid_sizes) - 1
    return next(
        (level for level, size in enumerate(grid_sizes) if n_visits <= size * size),
        fallback,
    )


# ----------------------- PM-DA-MCTS selection / backup -----------------------

def poly_bonus(C: float, Ns: int, Na: int) -> float:
    """
    Polynomial exploration bonus C * Ns^(1/4) / Na^(1/2).

    Both counts are clamped to at least 1, so unvisited parents/actions do not
    cause a division by zero; callers should still try unseen actions first.
    """
    parent_term = max(1, Ns) ** 0.25
    action_term = max(1, Na) ** 0.5
    return C * parent_term / action_term


def select_action_from_grid(grid: "GridLevel",
                            node_visits: int,
                            C: float, L_holder: float, beta: float,
                            eps1: float) -> Tuple[float, ...]:
    """
    Choose an action from one ε-net level.

    Precedence:
      1. If any action is unvisited, return one of them uniformly at random
         (every action is tried once before scores are used).
      2. With probability eps1, return a uniformly random action of the grid.
      3. Otherwise return the argmax of
             mean_Q(a) + L_holder * epsilon^beta + C * Ns^(1/4) / Na^(1/2)
         where the first maximal action in iteration order wins ties.

    Args:
        grid: the ε-net level to choose from (reads .stats and .epsilon).
        node_visits: visit count Ns of the owning node.
        C: exploration constant of the polynomial bonus.
        L_holder, beta: Hölder constants of the discretization bias term.
        eps1: probability of a uniformly random choice once all are visited.

    Returns:
        A key of grid.stats.

    Raises:
        ValueError: if the grid has no actions. The original silently
            returned None here, which only failed later inside the env.
    """
    if not grid.stats:
        raise ValueError("select_action_from_grid: grid has no actions (is cap_children >= 1?)")

    # Initialization: try every action once before trusting any score.
    unvisited = [a for a, st in grid.stats.items() if st.visits == 0]
    if unvisited:
        return random.choice(unvisited)

    # Optional ε-greedy randomization.
    if random.random() < eps1:
        return random.choice(list(grid.stats.keys()))

    # The bias term is identical for every action of this level, so it shifts
    # all scores equally and does not change the argmax; kept to mirror the
    # paper's formula.
    bias = L_holder * (grid.epsilon ** beta)

    def score(item):
        _, st = item
        return st.mean_q + bias + poly_bonus(C, node_visits, st.visits)

    # max() keeps the first maximal item, matching a strict '>' scan, but
    # always yields an action (no sentinel that very low or NaN scores could
    # fail to beat).
    best_a, _ = max(grid.stats.items(), key=score)
    return best_a


def _grid_for_visits(node: "Node", act_low: float, act_high: float, d: int,
                     beta_holder: float, eps1_grid: float, cap_children: int,
                     kmax: int = 8) -> "GridLevel":
    """Return node's ε-net level k(n) = min{k : n <= |N_k|^2}, building levels lazily."""
    n = node.visit_count
    grid = None
    for k in range(kmax):
        grid = node.grids.get(k)
        if grid is None:
            # epsilon_k = eps1_grid * 2^{-k/(d+2β)}
            eps = eps1_grid * (2.0 ** (-(k) / (d + 2.0 * beta_holder)))
            actions = build_uniform_grid(act_low, act_high, d, eps, cap_children)
            grid = GridLevel(epsilon=eps, actions=actions,
                             stats={a: ActionStats() for a in actions})
            node.grids[k] = grid
        # Same rule as k_from_visits: level 0 for unvisited nodes, otherwise
        # the first level with n <= |N_k|^2. Stopping here means finer levels
        # are only built once a node has enough visits to need them.
        if n <= 0 or n <= len(grid.actions) * len(grid.actions):
            return grid
    # No level is large enough: use the finest one (k = kmax - 1).
    return grid


def pmda_selection(env_plan: SnapshotEnv,
                   node: Node,
                   depth_limit: int,
                   p_power: float,
                   act_low: float, act_high: float, act_dim: int,
                   eps1: float,
                   C: float,
                   L_holder: float,
                   beta_holder: float,
                   eps1_grid: float,
                   cap_children: int) -> float:
    """
    Run one selection / expansion / backup pass starting at `node`.

    Picks the ε-net level k(n) for the node, selects an action on it with
    select_action_from_grid, branches the planning env from the node's
    snapshot, recurses into a transient child, then backs the sampled return
    into the chosen action's running-mean statistics.

    Args:
        env_plan: snapshot-capable planning env; only get_snapshot() and
            get_result(snapshot, action) are used here.
        node: node to expand; its grids and statistics are updated in place.
        depth_limit: a node at this depth (or deeper) returns 0.0.
        p_power: power p used for the sum_p diagnostic accumulator.
        act_low, act_high, act_dim: box action space [act_low, act_high]^act_dim.
        eps1: ε-greedy probability passed to select_action_from_grid.
        C: exploration constant of the polynomial bonus.
        L_holder, beta_holder: Hölder constants (bias term and ε_k schedule).
        eps1_grid: ε_1, spacing parameter of the coarsest level.
        cap_children: hard cap on actions per level.

    Returns:
        The sampled discounted return q = r + DISCOUNT * tail from `node`
        (0.0 if the node is terminal or at the depth limit).

    Note:
        Recursion uses one Python frame per depth level, so very large
        depth_limit values can hit the interpreter's recursion limit.
    """
    if node.done or node.depth >= depth_limit:
        return 0.0

    # Ensure the root snapshot is ready.
    if node.snapshot is None:
        node.snapshot = env_plan.get_snapshot()

    grid = _grid_for_visits(node, act_low, act_high, act_dim,
                            beta_holder, eps1_grid, cap_children)

    a = select_action_from_grid(
        grid=grid, node_visits=node.visit_count,
        C=C, L_holder=L_holder, beta=beta_holder, eps1=eps1
    )

    # Next-state result by branching from this node's snapshot.
    res = env_plan.get_result(node.snapshot, a)
    snap_next = _get_attr(res, "snapshot")
    rew = float(_get_attr(res, "reward", 0.0))
    done = bool(_get_attr(res, "is_done", False))

    if snap_next is None:
        # No snapshot to continue from: back up the immediate reward only.
        q_target = rew
    else:
        # Transient child (the subtree is not stored, to keep memory small);
        # its visit_count is always 0, so it always uses level 0.
        child = Node(snapshot=snap_next, depth=node.depth + 1, done=done)
        tail = pmda_selection(
            env_plan, child, depth_limit, p_power,
            act_low, act_high, act_dim, eps1, C, L_holder, beta_holder, eps1_grid, cap_children
        )
        q_target = rew + DISCOUNT * tail

    # Backup into this node.
    node.visit_count += 1
    st = grid.stats[a]
    st.visits += 1
    st.mean_q += (q_target - st.mean_q) / st.visits  # running average

    # Power accumulator: sums mean_q ** p along the running-mean trajectory
    # (diagnostic only). Terms that are not real numbers are skipped: a
    # negative mean with a non-integer p would make Python return a complex
    # value, and 0 ** (negative p) raises ZeroDivisionError.
    m = st.mean_q
    complex_term = m < 0.0 and not float(p_power).is_integer()
    zero_div_term = m == 0.0 and p_power < 0.0
    if not (complex_term or zero_div_term):
        st.sum_p += m ** p_power

    return q_target


# --------------------------- Top-level evaluation -----------------------------

def run_seed(env_name: str,
             iterations: int,
             seed: int,
             C: float,
             eps1: float,
             L: float,
             frac: float,
             POWER: float,
             cap: int,
             test_horizon: int = 150,
             plan_depth: int = 100,
             beta_holder: float = 1.0) -> float:
    """
    Run one PM-DA-MCTS episode with replanning each step; return its discounted return.

    Args map to the paper:
      - C: exploration constant in polynomial bonus
      - eps1: ε-greedy over grid actions (set 0.0 for pure optimistic policy)
      - L: Hölder constant (discretization bias term L * ε_k^β)
      - frac: ε_1 (coarsest ε-net radius / grid spacing proxy)
      - POWER: power parameter p forwarded to pmda_selection
      - cap: hard cap on actions per node (safety)
      - test_horizon: maximum number of executed steps
      - plan_depth: depth limit of each planning rollout
      - beta_holder: Hölder exponent β used in the ε_k schedule

    Returns:
        Σ_t DISCOUNT^t * r_t over the executed trajectory.
    """
    random.seed(seed)
    np.random.seed(seed)
    cap = int(cap)

    # Action space info (hard-coded for Walker2d: 6-D in [-1, 1]).
    act_low, act_high, act_dim = -1.0, 1.0, 6
    KMAX = 8  # number of ε-net levels; keep in sync with pmda_selection

    def random_action() -> Tuple[float, ...]:
        # Uniform sample from the action box.
        return tuple(random.uniform(act_low, act_high) for _ in range(act_dim))

    def plan(root_node: Node) -> None:
        # `iterations` selection/backup passes from root_node.
        for _ in range(iterations):
            pmda_selection(plan_env, root_node, plan_depth, POWER,
                           act_low, act_high, act_dim,
                           eps1, C, L, beta_holder, frac, cap)

    def greedy_action(root_node: Node) -> Tuple[float, ...]:
        # Best-mean_q action on the root's current level k(n); random if the
        # root (or that level) has no visited action yet.
        # NOTE(review): statistics on other levels are ignored here.
        if root_node.visit_count == 0:
            return random_action()
        sizes = []
        for k in range(KMAX):
            if k not in root_node.grids:
                eps = frac * (2.0 ** (-(k) / (act_dim + 2.0 * beta_holder)))
                actions = build_uniform_grid(act_low, act_high, act_dim, eps, cap)
                root_node.grids[k] = GridLevel(epsilon=eps, actions=actions,
                                               stats={a: ActionStats() for a in actions})
            sizes.append(len(root_node.grids[k].actions))
        level = root_node.grids[k_from_visits(root_node.visit_count, sizes)]
        visited = [(a, st) for a, st in level.stats.items() if st.visits > 0]
        if not visited:
            return random_action()
        return max(visited, key=lambda item: item[1].mean_q)[0]

    # Snapshot-capable planning env.
    plan_env = SnapshotEnv(gym.make(env_name).env)
    test_env = None
    total, df = 0.0, 1.0
    try:
        # Seed the simulator too, so reset() draws the same initial state for
        # the same `seed`. Best-effort: envs without seed() are left as-is.
        # NOTE(review): assumes SnapshotEnv forwards seed() to the gym env.
        seed_fn = getattr(plan_env, "seed", None)
        if callable(seed_fn):
            seed_fn(seed)
        plan_env.reset()
        root_snap = plan_env.get_snapshot()

        # Initial planning at the root.
        root = Node(snapshot=root_snap, depth=0, done=False)
        plan(root)

        # Test rollout from the root snapshot, replanning after every step.
        test_env = pickle.loads(root_snap)
        for step in range(test_horizon):
            a_exec = greedy_action(root)
            _, r, done, _ = test_env.step(a_exec)
            total += df * r
            df *= DISCOUNT
            if done or step == test_horizon - 1:
                # No further action will be executed: skip the replan.
                break

            # Re-root at the current test state and replan.
            root = Node(snapshot=pickle.dumps(test_env), depth=0, done=False)
            plan(root)
    finally:
        if test_env is not None:
            test_env.close()
        close_plan = getattr(plan_env, "close", None)
        if callable(close_plan):
            close_plan()
    return float(total)


# ------------------------------ CLI experiment --------------------------------

def run_many(env: str,
             iters: List[int],
             seeds: int,
             C: float, eps1: float, L: float, frac: float, POWER: float, cap_children: int,
             test_horizon: int, plan_depth: int, beta_holder: float,
             out_txt: str, out_csv: str):
    """
    Sweep planning budgets and record mean/std of episode returns per budget.

    For each budget `it` in `iters`, runs run_seed(...) for seeds 0..seeds-1.
    It then prints and appends one summary line to `out_txt`. It also writes
    one row to `out_csv`; that file is overwritten and starts with a header.
    Both files are flushed after every budget, so partial sweeps survive
    interruption.

    Raises:
        ValueError: if `seeds` < 1. This is checked before any directory or
            file is created, so an invalid call leaves no truncated CSV behind.
    """
    if seeds < 1:
        raise ValueError(f"run_many: seeds must be >= 1, got {seeds}")

    Path(out_txt).parent.mkdir(parents=True, exist_ok=True)
    Path(out_csv).parent.mkdir(parents=True, exist_ok=True)

    # Explicit UTF-8: the summary line contains "±", which the platform's
    # default encoding is not guaranteed to represent.
    with open(out_txt, "a", encoding="utf-8") as ftxt, \
            open(out_csv, "w", encoding="utf-8") as fcsv:
        fcsv.write("env,iter,seeds,C,eps1,L,frac,POWER,cap_children,mean,std\n")
        for it in iters:
            vals = [run_seed(env, it, sd, C, eps1, L, frac, POWER, cap_children,
                             test_horizon=test_horizon, plan_depth=plan_depth,
                             beta_holder=beta_holder)
                    for sd in range(seeds)]
            mean_v = statistics.mean(vals)
            # Population std over seeds; the "±" shown in the text line is
            # 2 * std (not a standard error).
            std_v = statistics.pstdev(vals) if len(vals) > 1 else 0.0
            msg = (f"Env={env}, ITER={it}: Mean={mean_v:.3f} ± {2.0*std_v:.3f} "
                   f"(over {seeds} seeds) [C={C}, eps1={eps1}, L={L}, eps1_grid={frac}, "
                   f"p={POWER}, cap={cap_children}, beta={beta_holder}]")
            print(msg)
            ftxt.write(msg + "\n")
            ftxt.flush()
            fcsv.write(f"{env},{it},{seeds},{C},{eps1},{L},{frac},{POWER},{cap_children},{mean_v:.6f},{std_v:.6f}\n")
            fcsv.flush()
    print(f"Done. Wrote {out_txt} and {out_csv}.")


def parse_args(argv: Optional[List[str]] = None):
    """
    Parse the command-line flags of the PM-DA-MCTS sweep.

    Args:
        argv: arguments to parse, without the program name. The default,
            None, makes argparse read sys.argv[1:] (the previous behavior).
            Passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with one attribute per flag below.
    """
    ap = argparse.ArgumentParser(
        description="Power-Mean Dimension-Adaptive MCTS (PM-DA-MCTS) sweep on Walker2d.")
    ap.add_argument("--env", type=str, default="ImprovedWalker2d-v0", help="Gym environment id")
    ap.add_argument("--iters", nargs="+", type=int, default=[3, 4, 7, 11, 18, 29],
                    help="Planning iterations per decision; one sweep point per value")
    ap.add_argument("--seeds", type=int, default=20, help="Number of seeds (0..seeds-1) per sweep point")
    ap.add_argument("--test_horizon", type=int, default=150, help="Maximum executed steps per episode")
    ap.add_argument("--plan_depth", type=int, default=100, help="Depth limit of each planning rollout")
    # Paper parameters (mapped to your usual flags)
    ap.add_argument("--C", type=float, default=16.0, help="Exploration constant in polynomial bonus")
    ap.add_argument("--eps1", type=float, default=0.0, help="ε-greedy over grid actions (0=off)")
    ap.add_argument("--L", type=float, default=1.0, help="Hölder constant L for L * ε_k^β")
    ap.add_argument("--frac", type=float, default=0.5, help="ε_1: coarsest grid spacing proxy")
    ap.add_argument("--POWER", type=float, default=2.0, help="Power p for power-mean backups")
    ap.add_argument("--cap_children", type=int, default=64, help="Hard cap on actions per node")
    ap.add_argument("--beta", type=float, default=1.0, help="Hölder β for ε_k schedule")
    ap.add_argument("--out_txt", type=str, default="pmda_walker2d.txt", help="Summary log file (appended)")
    ap.add_argument("--out_csv", type=str, default="pmda_walker2d.csv", help="CSV results file (overwritten)")
    return ap.parse_args(argv)


def main():
    """Command-line entry point: echo the parsed settings, then run the sweep."""
    opts = parse_args()
    settings_lines = (
        "[PM-DA-MCTS] Settings:",
        f" env={opts.env}",
        f" iters={opts.iters} seeds={opts.seeds} test_horizon={opts.test_horizon} plan_depth={opts.plan_depth}",
        f" C={opts.C} eps1={opts.eps1} L={opts.L} eps1_grid={opts.frac} p={opts.POWER} cap={opts.cap_children} beta={opts.beta}",
    )
    for line in settings_lines:
        print(line)
    run_many(
        env=opts.env, iters=opts.iters, seeds=opts.seeds,
        C=opts.C, eps1=opts.eps1, L=opts.L, frac=opts.frac, POWER=opts.POWER,
        cap_children=opts.cap_children,
        test_horizon=opts.test_horizon, plan_depth=opts.plan_depth,
        beta_holder=opts.beta,
        out_txt=opts.out_txt, out_csv=opts.out_csv,
    )


# Run the CLI sweep only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
